syntax = "proto3";

package tensorflow.eager;

import "tensorflow/core/framework/attr_value.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";
import "tensorflow/core/protobuf/tensorflow_server.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/tensor.proto";

// Identifies a remote tensor by the <op ID, output number> pair of the
// operation that produced it.
//
// NOTE(review): op_id is presumably either an Operation.id chosen by the
// client or the artificial SendTensorRequest.op_id -- confirm.
message RemoteTensorHandle {
  // The ID of the operation that produced this tensor.
  int64 op_id = 1;
  // The index into the outputs of the operation that produced this tensor.
  int32 output_num = 2;
}

// A proto representation of an eager operation. Operations are sent to the
// worker wrapped in a QueueItem.
message Operation {
  // A unique identifier for the operation. Set by the client so that the client
  // can uniquely identify the outputs of the scheduled operation.
  //
  // In the initial implementation, sending duplicate IDs has undefined
  // behaviour, but additional constraints may be placed upon this in the
  // future.
  int64 id = 1;
  // NOTE(review): presumably the name of the op (or registered function) to
  // run, rather than a per-instance label -- TODO confirm.
  string name = 2;
  // Handles of the tensors this operation takes as inputs.
  // NOTE(review): order presumably matches the op's input order -- confirm.
  repeated RemoteTensorHandle inputs = 3;

  // Control Operation IDs that will be respected when ops are re-ordered by
  // async execution. If async execution (+ op re-ordering) is not enabled, this
  // should have no effect.
  repeated int64 control_op_ids = 4;
  // Attribute values for the operation, keyed by string.
  // NOTE(review): keys are presumably the op's attribute names -- confirm.
  // Map entries have no defined order on the wire.
  map<string, AttrValue> attrs = 5;
  // NOTE(review): presumably the name of the device on which the operation
  // should run; the expected format and the meaning of an empty string are not
  // specified here -- TODO confirm.
  string device = 6;
}

// One entry of EnqueueRequest.queue: either an operation to run or a tensor
// handle to release.
message QueueItem {
  // The remote executor should be able to handle either executing ops directly,
  // or releasing any unused tensor handles, since the tensor lifetime is
  // maintained by the client.
  //
  // As a oneof, at most one of the members below is set on any given item.
  oneof item {
    // A tensor handle the client no longer uses and that may be released.
    RemoteTensorHandle handle_to_decref = 1;
    // An operation to execute.
    Operation operation = 2;
  }
}

// The response for a single QueueItem; EnqueueResponse carries one of these
// per item in the request.
message QueueResponse {
  // NOTE(review): presumably the shapes of the executed operation's outputs,
  // indexed by output number, and empty for handle_to_decref items -- TODO
  // confirm against the service implementation.
  repeated TensorShapeProto shape = 1;
}

// Request for EagerService.CreateContext: describes the cluster and the
// execution settings for the context to be created on the worker.
message CreateContextRequest {
  // Identifies the full cluster, and this particular worker's position within.
  ServerDef server_def = 1;

  // Whether the ops on the worker should be executed synchronously or
  // asynchronously. By default, ops are executed synchronously (the proto3
  // default for an unset bool is false).
  //
  // NOTE(review): `async` is a reserved word in some target languages (e.g.
  // Python 3.7+), which makes the generated accessor awkward to use. Renaming
  // the field would be source-breaking, so it is left as is.
  bool async = 2;

  // Number of seconds to keep the context alive. If more than keep_alive_secs
  // has passed since a particular context has been communicated with, it will
  // be garbage collected.
  //
  // NOTE(review): the behaviour for an unset/zero value is not specified
  // here -- TODO confirm.
  int64 keep_alive_secs = 3;

  // This is the version for all the ops that will be enqueued by the client.
  VersionDef version_def = 4;

  // This ID will be used for all future communications. It is essential that
  // both ends use this ID for selecting a rendezvous to get everything to
  // match.
  int64 rendezvous_id = 5;
}

// Response for EagerService.CreateContext: returns the ID of the new context
// and the devices available on the worker.
message CreateContextResponse {
  // The ID of the created context. This is usually a randomly generated number,
  // that will be used to identify the context in future requests to the
  // service. Contexts are not persisted through server restarts.
  fixed64 context_id = 1;

  // List of devices that are locally accessible to the worker.
  repeated DeviceAttributes device_attributes = 2;
}

// Request for EagerService.Enqueue: a batch of queue items for one context.
//
// NOTE(review): field number 2 is not used by this message. If it belonged to
// a field that was removed, it should be declared `reserved 2;` so it cannot
// be reused -- confirm against the file's history.
message EnqueueRequest {
  // The ID of the context to enqueue on, as returned in
  // CreateContextResponse.context_id.
  fixed64 context_id = 1;

  // The operations to run and tensor handles to release.
  // NOTE(review): presumably processed in the order given -- confirm.
  repeated QueueItem queue = 3;
}

// Response for EagerService.Enqueue.
message EnqueueResponse {
  // A single operation response for every item in the request.
  // NOTE(review): presumably in the same order as EnqueueRequest.queue --
  // confirm.
  repeated QueueResponse queue_response = 1;
}

// Request for EagerService.WaitQueueDone: names the ops to wait for within a
// context.
message WaitQueueDoneRequest {
  // The ID of the context, as returned in CreateContextResponse.context_id.
  fixed64 context_id = 1;

  // Ids to wait on. If empty, wait on everything currently pending.
  repeated int64 op_id = 2;
}

// Response for EagerService.WaitQueueDone. Currently carries no fields.
message WaitQueueDoneResponse {
  // TODO(nareshmodi): Consider adding NodeExecStats here to be able to
  // propagate some stats.
}

// Request for EagerService.KeepAlive.
message KeepAliveRequest {
  // The ID of the context to keep alive, as returned in
  // CreateContextResponse.context_id.
  fixed64 context_id = 1;
}

// Response for EagerService.KeepAlive. Currently carries no fields.
message KeepAliveResponse {
}

// Request for EagerService.CloseContext.
message CloseContextRequest {
  // The ID of the context to close, as returned in
  // CreateContextResponse.context_id.
  fixed64 context_id = 1;
}

// Response for EagerService.CloseContext. Currently carries no fields.
message CloseContextResponse {
}

// Request for EagerService.RegisterFunction.
message RegisterFunctionRequest {
  // The ID of the context in which to register the function, as returned in
  // CreateContextResponse.context_id.
  fixed64 context_id = 1;

  // The definition of the function to register.
  FunctionDef function_def = 2;
}

// Response for EagerService.RegisterFunction. Currently carries no fields.
message RegisterFunctionResponse {
}

// Request for EagerService.SendTensor: pushes tensor values directly to the
// server under a client-supplied op ID.
message SendTensorRequest {
  // The ID of the context that receives the tensors, as returned in
  // CreateContextResponse.context_id.
  fixed64 context_id = 1;

  // All remote tensors are identified by <Op ID, Output num>. To mimic this
  // situation when directly sending tensors, we include an "artificial" op ID
  // (which would have corresponded to the _Recv op when not using SendTensor).
  int64 op_id = 2;
  // The index within the repeated field is the output number that will help
  // uniquely identify (along with the above op_id) the particular tensor.
  repeated TensorProto tensors = 3;

  // The device on which the tensors should be resident.
  string device_name = 4;
}

// Response for EagerService.SendTensor. Currently carries no fields.
message SendTensorResponse {
}

////////////////////////////////////////////////////////////////////////////////
//
// Eager Service defines a TensorFlow service that executes operations eagerly
// on a set of local devices, on behalf of a remote Eager executor.
//
// The service impl will keep track of the various clients and devices it has
// access to and allows the client to enqueue ops on any devices that it is able
// to access and schedule data transfers from/to any of the peers.
//
// A client can generate multiple contexts to be able to independently execute
// operations, but cannot share data between those contexts.
//
// NOTE: Even though contexts generated by clients should be independent, the
// lower level tensorflow execution engine is not, so they might share some data
// (e.g. a Device's ResourceMgr).
//
////////////////////////////////////////////////////////////////////////////////
service EagerService {
  // This initializes the worker, informing it about the other workers in the
  // cluster and exchanging authentication tokens which will be used in all
  // other RPCs to detect whether the worker has restarted.
  //
  // NOTE(review): the "token" is presumably
  // CreateContextResponse.context_id, which every other request carries as
  // `context_id` -- confirm.
  rpc CreateContext(CreateContextRequest) returns (CreateContextResponse);

  // This takes a list of Execute and DeleteTensorHandle operations and enqueues
  // (in async mode) or executes (in sync mode) them on the remote server.
  // All outputs of ops which were not explicitly deleted with
  // DeleteTensorHandle entries will be assumed to be alive and are usable by
  // future calls to Enqueue.
  //
  // NOTE(review): no messages named "Execute" or "DeleteTensorHandle" are
  // defined in the visible part of this file; they appear to correspond to
  // QueueItem.operation and QueueItem.handle_to_decref respectively --
  // confirm and update the wording.
  rpc Enqueue(EnqueueRequest) returns (EnqueueResponse);

  // Takes a set of op IDs and waits until those ops are done. Returns any error
  // in the stream so far.
  rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);

  // Contexts are always created with a deadline (see
  // CreateContextRequest.keep_alive_secs). If no RPCs are received for a
  // context within that deadline, the context is garbage collected. KeepAlive
  // calls can be used to delay this.
  rpc KeepAlive(KeepAliveRequest) returns (KeepAliveResponse);

  // Closes the context. No calls to other methods using the existing context ID
  // are valid after this.
  rpc CloseContext(CloseContextRequest) returns (CloseContextResponse);

  // Takes a FunctionDef and makes it enqueueable on the remote worker.
  rpc RegisterFunction(RegisterFunctionRequest)
      returns (RegisterFunctionResponse);

  // An RPC to push tensors to the server. At times, certain environments don't
  // allow the server to connect back to the client.
  rpc SendTensor(SendTensorRequest) returns (SendTensorResponse);
}
